from datasets import load_dataset

# Default prompt wrapper. The `{raw_prompt}` placeholder is presumably filled via
# `str.format(raw_prompt=...)` by subclasses (not visible in this section).
DEFAULT_PROMPT_TEMPLATE = "\n\nHuman:\n{raw_prompt}\n\nAssistant:\n"


class BaseDatasetProcessor:
    """Base class for loading a dataset and carving it into train/validation/test.

    Subclasses are expected to set ``dataset_name`` and implement
    ``get_preference_dataset``.
    """

    # Name/path passed to ``load_dataset`` by ``get_raw_dataset``; subclasses must set it.
    dataset_name = None

    def __init__(self, num_proc=4, sanity_check=False, prompt_template=DEFAULT_PROMPT_TEMPLATE):
        """Store processing options.

        Args:
            num_proc: Stored for subclasses; presumably the worker count for
                dataset mapping (not used in this section — confirm in subclasses).
            sanity_check: If true, ``get_raw_dataset`` keeps at most the first
                10 rows before splitting, for quick smoke runs.
            prompt_template: Template string stored on the instance for
                subclasses to format prompts with.
        """
        self.num_proc = num_proc
        self.sanity_check = sanity_check
        self.prompt_template = prompt_template

    def get_preference_dataset(self, split, seed, removed_dimensions=None):
        """Return the processed preference dataset for ``split``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def get_raw_dataset(self, split, seed):
        """Load ``dataset_name`` and return one split of an 80/10/10 partition.

        The base implementation loads only the source ``"train"`` split and
        carves validation and test sets out of it. Both ``train_test_split``
        calls are seeded with ``seed``. Override this method when the dataset
        ships its own validation/test splits.

        Args:
            split: One of ``"train"``, ``"validation"`` or ``"test"``.
            seed: Seed forwarded to both ``train_test_split`` calls.

        Returns:
            The requested subset (80% train, 10% validation, 10% test).

        Raises:
            ValueError: If ``split`` is not a recognized split name.
            NotImplementedError: If the subclass did not set ``dataset_name``.
        """
        # Validate cheap arguments before the (potentially slow) dataset load.
        # The original code evaluated `NotImplementedError` without raising it,
        # silently returning None for an unknown split.
        if split not in ("train", "validation", "test"):
            raise ValueError(
                f"Unknown split {split!r}; expected 'train', 'validation' or 'test'."
            )
        if self.dataset_name is None:
            raise NotImplementedError("Subclasses must set `dataset_name`.")

        dataset = load_dataset(self.dataset_name, split="train")
        if self.sanity_check:
            # Keep a tiny subset for quick smoke runs.
            dataset = dataset.select(range(min(len(dataset), 10)))

        # 80% train / 20% held out.
        dataset_split = dataset.train_test_split(test_size=0.2, seed=seed)
        if split == "train":
            return dataset_split["train"]

        # Halve the held-out 20% into validation (0.5 * 0.2 = 10%) and test (10%).
        held_out = dataset_split["test"].train_test_split(test_size=0.5, seed=seed)
        return held_out["train"] if split == "validation" else held_out["test"]

    def select_and_rename_columns(self, dataset, schema):
        """Rename columns according to ``schema`` and drop all columns not in it.

        Args:
            dataset: A dataset exposing ``column_names``, ``rename_column`` and
                ``remove_columns`` (e.g. a ``datasets.Dataset``).
            schema: Mapping of source column name -> target column name.
                Source columns absent from ``dataset`` are skipped, and entries
                whose source and target names are equal are left as-is.

        Returns:
            The dataset with columns renamed, keeping only columns whose names
            appear among ``schema``'s values.
        """
        # Rename columns based on schema. Membership is re-checked against the
        # current column names after each rename.
        for old_name, new_name in schema.items():
            if old_name != new_name and old_name in dataset.column_names:
                dataset = dataset.rename_column(old_name, new_name)

        # Remove columns that are not in the schema (build the lookup set once).
        keep = set(schema.values())
        extra_columns = [col for col in dataset.column_names if col not in keep]
        return dataset.remove_columns(extra_columns)